Refactor tutorial to use dataclass for configuration #119
Conversation
Hi @Aatman09. Thank you for the nice commit. Could you please include a few pip installs at the beginning of the notebook for the additional dependencies? Please also include their versions, e.g.
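For illustration, such an install cell could look something like the following; the package names and version pins below are placeholders rather than the reviewer's original example, so swap in whatever the notebook actually imports:

```
# Placeholder install cell -- replace with the notebook's real extra
# dependencies and the versions the tutorial was tested against.
%pip install -q "flax==0.8.5" "optax==0.2.3" "tiktoken==0.7.0"
```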
For the KV cache, this would be nice to add in the
I think option 2 makes the most sense for this tutorial so it doesn't get too in the weeds on the cache. Implementing your own caching may also require writing your own attention layers. For more details, the nnx docs cover how to initialize a cache (https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/attention.html). This can be done with
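As a rough, standalone sketch of that cache setup with `flax.nnx` (the layer sizes here are made up and this is an assumption about the API described in those docs, not the tutorial's actual model):

```python
import jax.numpy as jnp
from flax import nnx

# Toy attention layer; the tutorial's decoder layers would be configured differently.
layer = nnx.MultiHeadAttention(
    num_heads=8, in_features=256, qkv_features=256,
    decode=False, rngs=nnx.Rngs(0))

# Allocate the autoregressive kv-cache: (batch, max_decode_length, features).
# max_decode_length must cover the whole generated sequence.
layer.init_cache((1, 30, 256), dtype=jnp.float32)

# Each call with decode=True writes one new token's key/value into the cache.
step = jnp.ones((1, 1, 256))
out = layer(step, decode=True)
```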
Thank you for the review. I will implement the changes as soon as possible.
I reduced the number of epochs from 10 to 2 for testing, which is why the graph looks different.
Hi @Aatman09. I see the issue. I'll make some quick general comments here.
For the decode sequence function, it should look closer to this:

```python
def decode_sequence(input_sentence):
    input_sentence = custom_standardization(input_sentence)
    tokenized_input_sentence = tokenize_and_pad(input_sentence, tokenizer, sequence_length)
    encoder_input = jnp.array([tokenized_input_sentence])

    # Run the encoder once; its outputs are reused at every decode step.
    emb_enc = model.positional_embedding(encoder_input)
    encoder_outputs = model.encoder(emb_enc, mask=None)

    dummy_input_shape = (1, 30, model.config.embed_dim)  # <- Update the cache size to be sufficiently large. I chose 30 here.
    model.init_cache(dummy_input_shape)

    decoded_sentence = "[start"
    current_token_id = tokenizer.encode("[start")
    current_input = jnp.array([current_token_id])

    for i in range(sequence_length):
        logits = model.decode_step(current_input, encoder_outputs, step_index=i)
        sampled_id = np.argmax(logits[0, 0, :]).item()
        sampled_token = tokenizer.decode([sampled_id])
        decoded_sentence += "" + sampled_token  # Your implementation had a space here, but this should be an empty string.
        if sampled_token == "[end]":
            break
        # Update input for next loop
        current_input = jnp.array([[sampled_id]])

    return decoded_sentence
```

The main issue is that the kv-cache is too small (only 1 token). With JAX, out-of-bounds indexing above the max index will just clamp to the last index, resulting in a silent bug. Make sure that the kv-cache is large enough to handle the decoding. I chose 30 here, but this should be decided programmatically (based on the input). English and Spanish are tokenized differently, so the exact number of output tokens isn't fully clear. Using a multiple of the number of input tokens should suffice (e.g. 2 or 3 times as many tokens for the cache). Updating the

Finally, you'll notice that these changes result in many
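Following that suggestion, a minimal sketch of sizing the cache programmatically (reusing the notebook's `tokenizer` and `model` objects; the factor of 3 is just the heuristic mentioned above):

```python
# Size the kv-cache from the input rather than hard-coding 30.
# The multiplier is a heuristic: the Spanish output may need more tokens
# than the English input, so leave generous headroom.
num_input_tokens = len(tokenizer.encode(input_sentence))
max_decode_len = 3 * num_input_tokens
dummy_input_shape = (1, max_decode_len, model.config.embed_dim)
model.init_cache(dummy_input_shape)
```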
Also, for the pip installs, using the
* Refactor tutorial to use dataclass for configuration
* Implemented KV caching (WIP)
* Final changes
Resolves #107
Reference
This implementation is based on the following tutorial:
JAX Machine Translation Tutorial
Changes made
Notes
Checklist
run_model.py for model usage, test_outputs.py and/or model_validation_colab.ipynb for quality).